/*
 * Copyright(C) 2024. Huawei Technologies Co.,Ltd. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <numeric>
#include <cmath>
#include <random>
#include <gtest/gtest.h>
#include <faiss/ascend/AscendIndexVStar.h>
#include <cstring>
#include <sys/time.h>
#include <cstdlib>
#include <memory>

using namespace faiss::ascend;
namespace {
inline double GetMillisecs()
{
    struct timeval tv = {0, 0};
    gettimeofday(&tv, nullptr);
    return tv.tv_sec * 1e3 + tv.tv_usec * 1e-3;
}

inline void Generate(size_t ntotal, std::vector<float> &data, int seed = 5678)
{
    std::default_random_engine e(seed);
    std::uniform_real_distribution<float> rCode(0.0f, 1.0f);
    data.resize(ntotal);
    for (size_t i = 0; i < ntotal; ++i) {
        data[i] = static_cast<float>(255 * rCode(e) - 128);
    }
}

inline void Norm(std::vector<float> &data, int dim)
{
    float square = 0.0;
    int nTotal = (dim == 0) ? 0 : static_cast<int>(data.size() / dim);
    for (int i = 0; i < nTotal; ++i) {
        square = 0.0;
        for (int j = 0; j < dim; ++j) {
            square += pow(data[i * dim + j], 2);
        }
        square = sqrt(square);
        for (int j = 0; j < dim; ++j) {
            data[static_cast<size_t>(i) * dim + j] /= square;
        }
    }
}

/**
 * AKMode需要提前生成算子和码本
 * 码本和算子参数根据实际情况调整, dim nlistL1 subDimL1 要与创建的索引一致
 * 算子：python3 vstar_generate_models.py --dim 1024 --nlist1 1024 --subDimL1 32
 * 码本：python3 vstar_train_codebook.py --dataPath {实际base数据路径} --dim 1024 --codebookPath {实际码本输出路径}
 --nListL1 1024 --subDimL1 32 --device 0
 */
TEST(TestAscendIndexVstar, Test_Search_Func)
{
    int dim = 1024;
    size_t ntotal = 1e5;
    int nlist = 256;
    int subSpaceDim = 128;
    std::vector<int> devices = {0};
    bool verbose = false;
    AscendIndexVstarInitParams aParams(dim, subSpaceDim, nlist, devices, verbose);
    auto index = std::make_shared<AscendIndexVStar>(aParams);

    // 添加码本 需要提前生成好码本路径
    std::string codebook = "/home/work/codebook_256_1024_128/codebook_l1_l2.bin";
    auto ret = index->AddCodeBooksByPath(codebook);
    EXPECT_EQ(ret, 0);

    // 生成base底库数据
    std::vector<float> data(ntotal);
    Generate(ntotal * dim, data);
    // 标准化
    Norm(data, dim);

    // add底库
    index->Add(data);
    size_t total = 0;
    index->GetNTotal(total);
    EXPECT_EQ(total, ntotal);
    
    // search检索
    int topk = 100;
    int warmUpTimes = 10;
    size_t nq = 9000;
    std::vector<float> distsWarm(nq * topk);
    std::vector<int64_t> labelsWarm(nq * topk);

    // warm up
    for (int i = 0; i < warmUpTimes; ++i) {
        AscendIndexSearchParams searchParamsWarm {100, data, topk, distsWarm, labelsWarm};
        index->Search(searchParamsWarm);
    }

    // search
    std::vector<size_t> searchNum = {1, 8, 16, 32, 64, 128, 256};
    int loopTimes = 100;
    for (auto n : searchNum) {
        std::vector<float> queryData(data.begin(), data.begin() + n * dim);
        std::vector<float> dists(n * topk, 0);
        std::vector<int64_t> labels(n * topk, 0);
        double ts = GetMillisecs();
        for (int i = 0; i < loopTimes; ++i) {
            AscendIndexSearchParams searchParams {n, queryData, topk, dists, labels};
            index->Search(searchParams);
        }
        double te = GetMillisecs();
        printf("base:%zu, dim:%d, search num:%zu, QPS:%.4f\n", ntotal, dim, n, 1000 * n * loopTimes / (te - ts));
    }
}
} // namespace

int main(int argc, char **argv)
{
    testing::InitGoogleTest(&argc, argv);

    return RUN_ALL_TESTS();
}